import torch
import torch.nn as nn
import torch.nn.functional as F
"""

https://github.com/wy1iu/LargeMargin_Softmax_Loss/blob/master/myexamples/cifar10/model/cifar_train_test.prototxt
"""

"""
-------------------conv0------------------------
conv->bn->relu
layer {
  name: "pre_conv"
  type: "Convolution"
  bottom: "data"
  top: "pre_conv"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  convolution_param {
    num_output: 64
    kernel_size: 3
    stride: 1
    pad: 1
    weight_filler {
      type: "msra"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "pre_bn"
  type: "BN"
  bottom: "pre_conv"
  top: "pre_bn"
  param {
    lr_mult: 1 
    decay_mult: 0 
  }
  param {
    lr_mult: 1
    decay_mult: 0
  }
  bn_param {
    slope_filler {
      type: "constant"
      value: 1
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
layer {
  name: "pre_relu"
  type: "ReLU"
  bottom: "pre_bn"
  top: "pre_bn"
}


############################### conv1.x ################################

conv1:( [conv(64,64,3,1),bn,relu],
        [conv(64,64,3,1),bn,relu],
        [conv(64,64,3,1),bn,relu],
        [conv(64,64,3,1),bn,relu])
        maxpool(2,2)
layer {
  name: "conv1_1"
  type: "Convolution"
  bottom: "pre_bn"
  top: "conv1_1"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  convolution_param {
    num_output: 64
    kernel_size: 3
    stride: 1
    pad: 1
    weight_filler {
      type: "msra"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "bn1_1"
  type: "BN"
  bottom: "conv1_1"
  top: "bn1_1"
  param {
    lr_mult: 1 
    decay_mult: 0 
  }
  param {
    lr_mult: 1
    decay_mult: 0
  }
  bn_param {
    slope_filler {
      type: "constant"
      value: 1
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
layer {
  name: "relu1_1"
  type: "ReLU"
  bottom: "bn1_1"
  top: "bn1_1"
}
layer {
  name: "conv1_2"
  type: "Convolution"
  bottom: "bn1_1"
  top: "conv1_2"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  convolution_param {
    num_output: 64
    kernel_size: 3
    stride: 1
    pad: 1
    weight_filler {
      type: "msra"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "bn1_2"
  type: "BN"
  bottom: "conv1_2"
  top: "bn1_2"
  param {
    lr_mult: 1 
    decay_mult: 0 
  }
  param {
    lr_mult: 1
    decay_mult: 0
  }
  bn_param {
    slope_filler {
      type: "constant"
      value: 1
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
layer {
  name: "relu1_2"
  type: "ReLU"
  bottom: "bn1_2"
  top: "bn1_2"
}
layer {
  name: "conv1_3"
  type: "Convolution"
  bottom: "bn1_2"
  top: "conv1_3"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  convolution_param {
    num_output: 64
    kernel_size: 3
    stride: 1
    pad: 1
    weight_filler {
      type: "msra"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "bn1_3"
  type: "BN"
  bottom: "conv1_3"
  top: "bn1_3"
  param {
    lr_mult: 1 
    decay_mult: 0 
  }
  param {
    lr_mult: 1
    decay_mult: 0
  }
  bn_param {
    slope_filler {
      type: "constant"
      value: 1
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
layer {
  name: "relu1_3"
  type: "ReLU"
  bottom: "bn1_3"
  top: "bn1_3"
}
layer {
  name: "conv1_4"
  type: "Convolution"
  bottom: "bn1_3"
  top: "conv1_4"
  param {
    lr_mult: 1
    decay_mult: 1
  }
  param {
    lr_mult: 2
    decay_mult: 0
  }
  convolution_param {
    num_output: 64
    kernel_size: 3
    stride: 1
    pad: 1
    weight_filler {
      type: "msra"
    }
    bias_filler {
      type: "constant"
    }
  }
}
layer {
  name: "bn1_4"
  type: "BN"
  bottom: "conv1_4"
  top: "bn1_4"
  param {
    lr_mult: 1 
    decay_mult: 0 
  }
  param {
    lr_mult: 1
    decay_mult: 0
  }
  bn_param {
    slope_filler {
      type: "constant"
      value: 1
    }
    bias_filler {
      type: "constant"
      value: 0
    }
  }
}
layer {
  name: "relu1_4"
  type: "ReLU"
  bottom: "bn1_4"
  top: "bn1_4"
}
layer {
  name: "pool1"
  type: "Pooling"
  bottom: "bn1_4"
  top: "pool1"
  pooling_param {
    pool: MAX
    kernel_size: 2
    stride: 2
  }
}
"""
class CNN(nn.Module):
    """VGG-style CIFAR backbone (PyTorch port of the L-Softmax prototxt above).

    Structure: a stem conv (conv0), then three stages of 4 conv blocks each
    (64 -> 64, 64 -> 128, 128 -> 256 channels), each stage followed by a 2x2
    max-pool, and a final fully-connected embedding layer.

    Args:
        out_feature: dimensionality of the output embedding (default 512).
        bn: whether the conv blocks built by ``_make_layer`` use BatchNorm.
            NOTE(review): ``conv0`` and ``fc`` always use batch norm regardless
            of this flag — presumably intentional, matching the prototxt,
            which applies BN everywhere; confirm before changing.
        in_channel: number of input image channels (default 3, RGB).
        in_size: spatial side length of the (square) input image. Must be a
            multiple of 8 so the three 2x2 pools divide it evenly
            (default 32, i.e. CIFAR).
    """

    # Kernel size and the padding that keeps spatial dims unchanged (stride 1).
    k = 3
    p = (k - 1) // 2

    def __init__(self, out_feature=512, bn=True, in_channel=3, in_size=32):
        super().__init__()
        self.conv0 = nn.Sequential(
            nn.Conv2d(in_channel, 64, self.k, 1, self.p),
            nn.BatchNorm2d(64),
            nn.ReLU())
        self.conv1 = self._make_layer(64, 64, 4, bn)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = self._make_layer(64, 128, 4, bn)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.conv3 = self._make_layer(128, 256, 4, bn)
        self.pool3 = nn.MaxPool2d(2, 2)

        # Three 2x2 max-pools each halve the spatial dims: in_size -> in_size/8.
        # With the default in_size=32 this is 256 * 4 * 4 = 256 * 16,
        # identical to the original hard-coded value.
        feat_side = in_size // 8
        self.fc = nn.Sequential(
            nn.Linear(256 * feat_side * feat_side, out_feature),
            nn.BatchNorm1d(out_feature),
            nn.ReLU())

    def _make_layer(self, in_channel, out_channel, number, batch_norm=False):
        """Stack ``number`` conv blocks; only the first changes channel count."""
        layers = ([self._conv(in_channel, out_channel, batch_norm), ] +
                  [self._conv(out_channel, out_channel, batch_norm) for _i in range(number - 1)])
        return nn.Sequential(*layers)

    def _conv(self, in_channel, out_channel, batch_norm):
        """One conv block: Conv -> (optional BatchNorm) -> PReLU."""
        if batch_norm:
            return nn.Sequential(nn.Conv2d(in_channel, out_channel, self.k, 1, self.p),
                                 nn.BatchNorm2d(out_channel),
                                 nn.PReLU())
        else:
            return nn.Sequential(nn.Conv2d(in_channel, out_channel, self.k, 1, self.p),
                                 nn.PReLU())

    def forward(self, x):
        """Map images (N, in_channel, in_size, in_size) to (N, out_feature)."""
        x = self.conv0(x)
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = self.pool3(x)
        return self.fc(x.reshape(x.size(0), -1))
